import pickle
import os
import numpy as np
from python_ai.common.xcommon import *

# Load the dataset splits from the pickle cache stored under `_data/`
# beside this script.
sep('Load data')  # helper from xcommon; presumably prints a section banner -- confirm

# Build the cache path from this script's location, not the current working directory.
BASE_DIR, FILE_NAME = os.path.split(__file__)
save_path = os.path.join(BASE_DIR, '_data', 'mnist_dataset_gen.py.pkl')
if not os.path.exists(save_path):
    # FileNotFoundError is a subclass of Exception, so any existing
    # `except Exception` handler still catches it; the type is just more specific.
    raise FileNotFoundError(
        f'{save_path} does not exist. Please run mnist_dataset_gen.py to generate it.'
    )
print(f'Loading data from {save_path}')

# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only load a cache that was generated locally / comes from a trusted source.
with open(save_path, 'rb') as f:
    data_dict = pickle.load(f)

# Expose each stored entry as a module-level name.
x_train = data_dict['x_train']
x_test = data_dict['x_test']
x_val = data_dict['x_val']
y_train = data_dict['y_train']
y_test = data_dict['y_test']
y_val = data_dict['y_val']
shape_ = data_dict['shape_']  # NOTE(review): meaning not visible here; presumably the per-sample shape -- confirm

# Report the shape of every loaded split, features first and then labels.
# Underscore-prefixed loop names keep them out of `from ... import *`.
for _label, _split in (
    ('x_train', x_train),
    ('x_val', x_val),
    ('x_test', x_test),
    ('y_train', y_train),
    ('y_val', y_val),
    ('y_test', y_test),
):
    print(_label, _split.shape)
